
Conversation

Priya2698 (Collaborator) commented Dec 18, 2025

No description provided.

github-actions bot commented Jan 7, 2026

Review updated until commit 62109c2

Description

  • Implement AllToAll collective communication primitive with full backend integration

  • Add lowering logic to convert AllToAll operations to communication primitives

  • Include AllToAll in communication type handling and layout management

  • Add comprehensive test coverage for multi-device AllToAll functionality

Changes walkthrough

Relevant files

Enhancement

lower_to_communication.cpp (csrc/host_ir/lower_to_communication.cpp): AllToAll lowering and communication detection (+48/-3)

  • Added lowerToAllToAll function to handle AllToAll operation lowering
  • Updated getCommunicationInfo to detect and handle AllToAll communication patterns
  • Modified getCommunicationLayout to include AllToAll in layout management
  • Added AllToAll case to convertSingleOpToCommunication

communication.cpp (csrc/multidevice/communication.cpp): AllToAll backend implementation and communication handling (+73/-0)

  • Added AllToAll to CommunicationType stream operator output
  • Updated hasRoot to return false for AllToAll (no root process)
  • Updated isReduction to return false for AllToAll (not a reduction)
  • Implemented postAllToAll function with tensor reshaping and backend alltoall_base call
  • Added AllToAll case to postSingleCommunication

communication.h (csrc/multidevice/communication.h): AllToAll communication type enum addition (+2/-1)

  • Added AllToAll to CommunicationType enum definition

Tests

test_multidevice.py (tests/python/multidevice/test_multidevice.py): AllToAll multi-device test implementation (+71/-0)

  • Added test_alltoall function with comprehensive multi-device testing
  • Tests tensor sharding, communication, and result verification
  • Includes detailed comments explaining AllToAll tensor layout transformations
  • Validates correct data movement across device mesh

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review
    Tensor Layout Assumptions

    The postAllToAll function makes strong assumptions about tensor layouts and performs multiple tensor operations (reshape, permute, contiguous). While the test validates dimension divisibility, the robustness of these layout assumptions across different input configurations should be verified. Consider adding more comprehensive validation or documentation of the expected tensor format.

    c10::intrusive_ptr<c10d::Work> postAllToAll(
        Communication* communication,
        DeviceIdxType my_device_index,
        c10d::Backend* backend,
        at::Tensor input_tensor,
        at::Tensor output_tensor) {
      NVF_ERROR(
          isTvContiguous(communication->in()),
          "Input tensor is not contiguous: ",
          communication->in(),
          " contiguity: ",
          communication->in()->domain()->getContiguityString());
      NVF_ERROR(
          isTvContiguous(communication->out()),
          "Output tensor is not contiguous: ",
          communication->out(),
          " contiguity: ",
          communication->out()->domain()->getContiguityString());
    
      // input_tv = [DIDx(d), n/d, m, ...]
      // output_tv = [n, DIDx(d), m/d, ...]
      // `n`: gathered dimension
      // `m`: scattered dimension
      // For alltoall correctness, we split `m` and reorder as [DIDx(d), d, n/d,
      // m/d, ...] such that alltoall_base splits across the `d` dimension.
    
      int64_t d = communication->team_size();
      auto input_sizes = input_tensor.sizes();
    
      NVF_CHECK(
          input_sizes.at(1) % d == 0,
          "Scattered dimension must be divisible by the team size");
    
      std::vector<int64_t> input_reshape_sizes(
          input_sizes.begin(), input_sizes.end());
      input_reshape_sizes.at(1) = d;
      input_reshape_sizes.insert(
          input_reshape_sizes.begin() + 2, input_sizes.at(1) / d);
      auto reshaped_input = input_tensor.reshape(input_reshape_sizes);
    
      std::vector<int64_t> permute_dims(input_reshape_sizes.size());
      std::iota(permute_dims.begin(), permute_dims.end(), 0);
      std::swap(permute_dims[0], permute_dims[1]);
    
      auto reordered_input = reshaped_input.permute(permute_dims).contiguous();
    
      auto flattened_input_tensor = viewAsCompact(reordered_input);
      auto flattened_output_tensor = viewAsCompact(output_tensor);
    
      // alltoall_base requires even splits of the input and output tensors.
      auto input_splits = at::tensor_split(
          flattened_input_tensor, communication->team_size(), /*dim=*/0);
      auto output_splits = at::tensor_split(
          flattened_output_tensor, communication->team_size(), /*dim=*/0);
      assertBuffersHaveSameSize(input_splits, output_splits);
    
      std::vector<int64_t> empty_split_sizes;
      return backend->alltoall_base(
          flattened_output_tensor,
          flattened_input_tensor,
          empty_split_sizes,
          empty_split_sizes,
          /*options=*/{});
    }
    Communication Type Detection Logic

    The logic for determining when to use AllToAll vs SendRecv (lines 417-423) appears to be a simple else branch. While this may be correct for the intended use case, the reasoning behind this decision and its correctness across different sharding scenarios should be validated.

    if (c_logical_id == p2c_map.at(p_logical_id)) {
      fill_communication_info(
          CommunicationType::SendRecv, p_logical_id, c_logical_id);
    } else {
      fill_communication_info(
          CommunicationType::AllToAll, nullptr, nullptr);
    }

Priya2698 (Collaborator, Author) commented:

    !test

Priya2698 changed the title from AllToAll Draft to AllToAll lowering on Jan 13, 2026
    // For the following communication types, the sharded_id does not have to be
    // outermost in allocation domain. Nonetheless, `tv` still needs to be
    // contiguous and therefore .contiguous() at the beginning of this function.
    // TODO(prmishra): Fix the layout for AllToAll.
Priya2698 (Collaborator, Author) commented on this snippet:

Depending on where we relayout the input/output to be compliant with alltoall requirements, this function and potentially ReorderShardedAxisPass will be affected. I will do it in a follow-up PR once the current PR has been reviewed and we agree on the approach for reordering the input/output of alltoall.

Priya2698 changed the title from AllToAll lowering to AllToAll implementation on Jan 13, 2026
Priya2698 marked this pull request as ready for review on January 13, 2026 02:09
greptile-apps bot (Contributor) commented Jan 13, 2026

    Greptile Overview

    Greptile Summary

    This PR implements AllToAll collective communication support for nvfuser's multidevice framework. AllToAll enables simultaneous gathering along one dimension and scattering along another, effectively resharding tensors across devices.
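
To make the resharding concrete, the following is a single-process sketch in plain ATen (not nvFuser code; the sizes n, m and the team size d are illustrative) of which block of data each rank contributes to each peer when gathering along dim 0 and scattering along dim 1:

    // Single-process illustration of AllToAll resharding semantics.
    // Before: rank r holds row block r of a global [n, m] tensor.
    // After:  rank r holds column block r of the same tensor, with all n rows.
    #include <ATen/ATen.h>
    #include <vector>

    int main() {
      const int64_t d = 2, n = 4, m = 6; // illustrative sizes; n % d == 0, m % d == 0
      at::Tensor full = at::arange(n * m).reshape({n, m});

      std::vector<at::Tensor> before = full.chunk(d, /*dim=*/0); // row shards
      std::vector<at::Tensor> after = full.chunk(d, /*dim=*/1);  // column shards

      // Rank r sends column block s of its row shard to rank s, where it
      // becomes row block r of rank s's column shard.
      for (int64_t r = 0; r < d; ++r) {
        for (int64_t s = 0; s < d; ++s) {
          at::Tensor sent = before[r].chunk(d, /*dim=*/1)[s];
          at::Tensor received = after[s].chunk(d, /*dim=*/0)[r];
          TORCH_CHECK(sent.equal(received));
        }
      }
      return 0;
    }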

    Key Implementation Components

    Communication Type Addition: Adds AllToAll enum value to CommunicationType and integrates it throughout the communication pipeline (operator overloads, type checking functions).

    Core Runtime Implementation: The postAllToAll() function implements the NCCL alltoall_base primitive with tensor reshaping and permutation logic to organize data by destination rank. The implementation handles contiguous tensor validation and flattening.
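
A minimal single-process sketch of that reorganization, using ATen directly (the local input shape [n/d, m, ...] and the helper name are assumptions taken from the layout comments quoted later in this review, not code from the PR):

    #include <ATen/ATen.h>
    #include <numeric>
    #include <vector>

    // Reorders a local input of assumed shape [n/d, m, ...] (rank >= 2) so that
    // an even split along dim 0 yields one contiguous block per destination
    // rank, mirroring the reshape/permute/contiguous steps in postAllToAll.
    at::Tensor reorderForAllToAll(const at::Tensor& input, int64_t d) {
      TORCH_CHECK(input.size(1) % d == 0, "scattered dim must be divisible by d");

      // Split dim 1 (the scattered dim) into [d, m/d]:
      // [n/d, m, ...] -> [n/d, d, m/d, ...].
      std::vector<int64_t> sizes = input.sizes().vec();
      sizes.at(1) = d;
      sizes.insert(sizes.begin() + 2, input.size(1) / d);
      at::Tensor reshaped = input.reshape(sizes);

      // Move the new `d` axis outermost: [d, n/d, m/d, ...], then materialize
      // it contiguously so alltoall_base's even split along dim 0 is per-peer.
      std::vector<int64_t> perm(sizes.size());
      std::iota(perm.begin(), perm.end(), 0);
      std::swap(perm[0], perm[1]);
      return reshaped.permute(perm).contiguous();
    }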

    Detection Logic: AllToAll is detected when both producer and consumer tensors are sharded but on different logical dimensions (c_logical_id != p2c_map[p_logical_id]). The detection sets both p_sharded_id and c_sharded_id to nullptr, which differs from other communication types.

    Layout Handling: Includes an explicit TODO comment indicating that layout support for AllToAll needs fixes. Currently returns early in getCommunicationLayout() without enforcing specific layout requirements. The test uses permute operations as a workaround for allocation domain limitations.

    Areas Requiring Attention

    1. Documentation Inconsistency: Comments in postAllToAll() describing tensor shapes don't clearly distinguish between logical sharding representation and actual runtime memory layout, which could confuse future maintainers.

    2. Nullptr Sharded IDs: Using nullptr for both sharded IDs in the AllToAll detection creates a fragile dependency where getCommunicationLayout() must explicitly handle AllToAll to avoid dereferencing null pointers.

    3. Layout TODO: The acknowledged TODO for fixing AllToAll layout suggests incomplete implementation that may limit the operation's applicability or require workarounds like explicit permutations.

    4. Limited Validation: The implementation validates that the scattered dimension is divisible by team size but doesn't validate output tensor dimensions or the relationship between gathered/scattered dimensions upfront.

    Confidence Score: 3/5

    • This PR is functionally correct for the tested use case but has implementation limitations and documentation issues that should be addressed
    • The implementation passes its test case and follows existing patterns for other communication operations. However, several concerns reduce confidence: (1) the acknowledged TODO comment for layout fixes indicates incomplete implementation, (2) using nullptr for sharded_ids creates fragile coupling with getCommunicationLayout(), (3) inconsistent comments could mislead future maintainers, and (4) validation could be more comprehensive. The test coverage is adequate but only validates one scenario. These are not critical bugs but technical debt that limits robustness.
    • Pay close attention to csrc/multidevice/communication.cpp (comment accuracy) and csrc/host_ir/lower_to_communication.cpp (nullptr handling and TODO resolution)

Important Files Changed

File Analysis

  • csrc/multidevice/communication.h (5/5): Adds AllToAll enum value to CommunicationType - straightforward addition with no issues
  • csrc/multidevice/communication.cpp (2/5): Implements postAllToAll with potential dimension indexing error in validation and reshape logic at lines 644-651
  • csrc/host_ir/lower_to_communication.cpp (3/5): Adds AllToAll detection and lowering logic; uses nullptr for sharded_ids which may cause issues; includes TODO for layout fixes
  • tests/python/multidevice/test_multidevice.py (4/5): Adds comprehensive test_alltoall with detailed documentation; test validates basic correctness

    Sequence Diagram

    sequenceDiagram
        participant User as User Code
        participant FD as FusionDefinition
        participant Lower as lower_to_communication
        participant Comm as Communication
        participant NCCL as NCCL Backend
        
        User->>FD: define AllToAll fusion with permute ops
        User->>FD: set sharding on different dims
        FD->>Lower: getCommunicationInfo(expr)
        Lower->>Lower: detect p_sharded && c_sharded
        Lower->>Lower: check if c_logical_id != p2c_map[p_logical_id]
        Lower->>Lower: return AllToAll with nullptr sharded_ids
        
        FD->>Lower: convertSingleOpToCommunication()
        Lower->>Lower: getCommunicationLayout() returns early for AllToAll
        Lower->>Lower: lowerToAllToAll()
        Lower->>Comm: create Communication(AllToAll)
        
        User->>FD: execute([input_tensor])
        FD->>Comm: postSingleCommunication()
        Comm->>Comm: postAllToAll()
        Comm->>Comm: check isTvContiguous for input/output
        Comm->>Comm: reshape input to split scattered dim
        Comm->>Comm: permute to move DIDx(d) outermost
        Comm->>Comm: flatten tensors with viewAsCompact
        Comm->>NCCL: alltoall_base(output, input)
        NCCL-->>Comm: Work handle
        Comm-->>User: output tensor with resharded data
    

greptile-apps bot (Contributor) left a comment: 4 files reviewed, 4 comments


    Comment on lines +633 to +638
    // input_tv = [DIDx(d), n/d, m, ...]
    // output_tv = [n, DIDx(d), m/d, ...]
    // `n`: gathered dimension
    // `m`: scattered dimension
    // For alltoall correctness, we split `m` and reorder as [DIDx(d), d, n/d,
    // m/d, ...] such that alltoall_base splits across the `d` dimension.
greptile-apps bot (Contributor) commented:

    The comment describing the input/output tensor shapes appears to be inconsistent with the actual implementation below.

    The comment states:

    • input_tv = [DIDx(d), n/d, m, ...]
    • m: scattered dimension (at position 2)

    But the code at lines 644-651 operates on dimension 1 (input_sizes.at(1)), treating it as the scattered dimension. Based on the test case in test_alltoall, the actual runtime input tensor shape is [n/d, m, k, ...] (after permutation), where:

    • Position 0: n/d (gathered dimension)
    • Position 1: m (scattered dimension)
    • Position 2+: other dimensions

    The comment should be updated to reflect the actual runtime tensor layout rather than the logical sharding representation, or clarify that it's describing logical sharding rather than physical memory layout.

    Comment on lines +417 to +423
if (c_logical_id == p2c_map.at(p_logical_id)) {
  fill_communication_info(
      CommunicationType::SendRecv, p_logical_id, c_logical_id);
} else {
  fill_communication_info(
      CommunicationType::AllToAll, nullptr, nullptr);
}
greptile-apps bot (Contributor) commented:

    Setting both p_sharded_id and c_sharded_id to nullptr for AllToAll communication may cause issues in other parts of the codebase that expect valid IterDomain pointers.

    For example, getCommunicationLayout() (lines 476-521) is called with sharded_id as a parameter, and at line 493 it checks posInDomain(layout.allocation_domain(), sharded_id), which could fail if sharded_id is nullptr.

    The current workaround is that AllToAll is explicitly handled in getCommunicationLayout() at line 489 to return early, but this creates a tight coupling and fragile dependency. Consider either:

    1. Storing the actual producer and consumer sharded IDs for AllToAll (even though they map to different logical dimensions)
    2. Adding explicit nullptr checks in all functions that consume CommunicationInfo
    3. Documenting this assumption clearly in the CommunicationInfo struct definition
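
The third option could look like the following sketch (hypothetical: the field names follow the p_sharded_id / c_sharded_id naming used in this review, and the real CommunicationInfo declaration may differ):

    class IterDomain;              // nvFuser IR node, defined elsewhere
    enum class CommunicationType;  // opaque declaration for this sketch

    struct CommunicationInfo {
      CommunicationType type;
      // Sharded IterDomains of the producer and consumer. Both are nullptr for
      // AllToAll, because producer and consumer are sharded on different
      // logical dimensions; any consumer of this struct (e.g.
      // getCommunicationLayout) must special-case AllToAll before
      // dereferencing these pointers.
      IterDomain* p_sharded_id = nullptr;
      IterDomain* c_sharded_id = nullptr;
    };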

    Comment on lines +484 to +489
  // TODO(prmishra): Fix the layout for AllToAll.
  if (type == CommunicationType::Reduce ||
      type == CommunicationType::Allreduce ||
      type == CommunicationType::Broadcast ||
-     type == CommunicationType::SendRecv) {
+     type == CommunicationType::SendRecv ||
+     type == CommunicationType::AllToAll) {
greptile-apps bot (Contributor) commented:

    The TODO comment indicates that the layout for AllToAll still needs to be fixed. This suggests the current implementation may not handle all cases correctly.

    Based on the test comments in test_alltoall, the current approach uses permute operations to avoid non-adjacent splits that would fail stride validation. However, this workaround has limitations:

    1. It requires explicit permutation operations in the fusion definition
    2. It may introduce extra copies when making tensors contiguous
    3. It doesn't expose the true allocation domain structure to the fusion

    Before merging, clarify:

    • What specific layout issues remain to be addressed?
    • Are there known failure cases with the current implementation?
    • Is this TODO blocking or can it be addressed in a follow-up PR?

    Comment on lines +643 to +645
NVF_CHECK(
    input_sizes.at(1) % d == 0,
    "Scattered dimension must be divisible by the team size");
greptile-apps bot (Contributor) commented:

    The error message says "Scattered dimension must be divisible by the team size" but this check only validates one dimension.

    For a complete AllToAll operation, both the scattered dimension (at position 1 in the runtime tensor) and the gathered dimension (at position 0) should be validated:

    1. The scattered dimension input_sizes.at(1) must be divisible by d (currently checked ✓)
    2. The gathered dimension should result in an output where the gathered size equals d * input_sizes.at(0)

    Consider adding validation for the output tensor dimensions as well to catch shape mismatches early, rather than relying on assertBuffersHaveSameSize at line 668 which only provides a generic error.
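
A standalone sketch of what such upfront validation could look like (it uses ATen's TORCH_CHECK in place of the codebase's NVF_CHECK so the snippet is self-contained; the expected local output shape [n, m/d, ...] is an assumption taken from this review's description, and the helper name is hypothetical):

    #include <ATen/ATen.h>

    // Validate local input ([n/d, m, ...]) and output ([n, m/d, ...]) shapes
    // before posting the collective, so mismatches fail with a specific message.
    void checkAllToAllShapes(
        const at::Tensor& input,
        const at::Tensor& output,
        int64_t d) {
      TORCH_CHECK(
          input.size(1) % d == 0,
          "Scattered dimension must be divisible by the team size");
      TORCH_CHECK(
          output.size(0) == input.size(0) * d,
          "Gathered output dimension must be d times local input dimension 0");
      TORCH_CHECK(
          output.size(1) == input.size(1) / d,
          "Scattered output dimension must be local input dimension 1 divided by d");
    }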

Priya2698 requested a review from wujingyue on January 13, 2026 02:36
